{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 逻辑回归\n",
    "通过线性回归模型，我们已经了解过一个模型基本构成，其实模型就是一个数学函数簇，通过数据只是学习函数簇的参数得到具体的函数，从本篇开始学习逻辑回归，在学习逻辑回归之前，我们大概聊一下模型的分类。\n",
    "通常来说，模型可以分如下几类：\n",
    "+ 监督学习模型：训练样本不包含答案\n",
    "+ 非监督学习模型：训练样本包含答案\n",
    "  - 回归模型：答案是连续的实数\n",
    "  - 分类模型：答案是离散的整数，分类模型又分为二分类和多分类模型\n",
    "\n",
    "那么回归模型属于监督学习的回归模型，而本篇介绍的逻辑回归，就是属于监督学习中的二分类模型。\n",
    "\n",
    "\n",
    "## 逻辑回归模型\n",
    "下面我们快速撕开逻辑回归模型本质\n",
    "+ 数据：根据问题来的，比如预测明天会不会下雨，我们会收集很多下雨天和非下雨天之前的天气情况，比如一天前的湿度、温度、前一周的平均下雨天数等【各种变量】，还会收集当天的天气【答案或者目标】，这个就是训练样本\n",
    "+ 模型：逻辑回归就是规定函数形式就是：y_hat = sigmod(-(b + a1\\*x1 + a2\\*x2 + an\\*xm))，sigmod 是一个将任意实数映射到 0-1 的概率的一个转换函数，sigmod(x)=1/(1+exp(-x))，sigmod(x) 的导数 = sigmod(1)\\*(1-sigmod(x))\n",
    "+ 参数：b a1 a2 am 就是参数\n",
    "+ 损失函数：z = sum(-(y\\*log(y_hat) + (1-y)\\*log(1-y_hat)))/n，n 是样本数量\n",
    "+ 优化方法：梯度下降等，得益于 PyTorch 自动求梯度，我们不再需要显示求出损失函数的导数\n",
    "\n",
    "可以看到，逻辑回归的模型一部分和线性回归很相似，这也是它名字中包含回归的原因。\n",
    "\n",
    "下面我们来用 PyTorch 来构建整个逻辑回归模型。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. 数据准备\n",
    "我们这里选择直接构造一个训练集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "X.shape torch.Size([10, 2])\ntrue_Y.shape torch.Size([10])\ntensor([0., 0., 1., 1., 0., 0., 1., 0., 1., 1.])\n"
    }
   ],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "X = torch.tensor([[2., 1.],\n",
    "    [2., 2.],\n",
    "    [5., 4.],\n",
    "    [4., 5.],\n",
    "    [2., 3.],\n",
    "    [3., 2.],\n",
    "    [6., 5.],\n",
    "    [4., 1.],\n",
    "    [6., 3.],\n",
    "    [7, 4]], dtype=torch.float\n",
    ")\n",
    "print('X.shape', X.shape)\n",
    "true_Y = torch.tensor([0, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=torch.float)\n",
    "print('true_Y.shape', true_Y.shape)\n",
    "print(true_Y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. 定义基本工具函数\n",
    "+ 模型函数，指定样本和权重，对应的输出\n",
    "+ 初始化参数函数，初始化所有必须的参数\n",
    "+ 损失函数，初始化衡量误差的函数\n",
    "+ 样本获取函数，如何批量获取样本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义回归模型\n",
    "def logic_reg_model(weight, bias, input, batch_size):\n",
    "    return torch.sigmoid(torch.mm(input, weight).view(batch_size) + bias)\n",
    "\n",
    "def params_init():\n",
    "    p_weights = torch.randn(2, 1, dtype=torch.float, requires_grad=True)\n",
    "    p_bias = torch.zeros(1, dtype=torch.float, requires_grad=True)\n",
    "    return p_weights, p_bias\n",
    "\n",
    "def loss_func(true_Y, hat_Y):\n",
    "    error = -(true_Y*torch.log(hat_Y) + (1-true_Y)*torch.log(hat_Y))\n",
    "    return error ** 2\n",
    "\n",
    "def sample_batchs(X, true_Y, sample_nums, batch_size):\n",
    "    res = []\n",
    "    inds = list(range(0, sample_nums))\n",
    "    np.random.shuffle(inds)\n",
    "    cur_ind = 0\n",
    "    while cur_ind + batch_size < sample_nums:\n",
    "        keep_inds = inds[cur_ind:cur_ind + batch_size]\n",
    "        res.append((torch.index_select(X, 0, torch.tensor(keep_inds, dtype=torch.int64)), \n",
    "            torch.index_select(true_Y, 0, torch.tensor(keep_inds, dtype=torch.int64))))\n",
    "        cur_ind += batch_size\n",
    "    keep_inds = inds[cur_ind:cur_ind + batch_size]\n",
    "    if keep_inds:\n",
    "        res.append((torch.index_select(X, 0, torch.tensor(keep_inds, dtype=torch.int64)), \n",
    "            torch.index_select(true_Y, 0, torch.tensor(keep_inds, dtype=torch.int64))))\n",
    "    return res"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 开始模型训练\n",
    "采用微批梯度下降法，需要指定以下参数：\n",
    "+ epoch_nums: 整个样本迭代训练几次\n",
    "+ batch_nums: 每次微批的数据量大小\n",
    "+ X: 输入的数据 X\n",
    "+ ture_Y: 样本真正的 Y\n",
    "+ step_ratio: 迭代步长"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "tensor([[ 0.8588],\n        [-0.2786]], requires_grad=True)\ntensor([0.], requires_grad=True)\nepoch=20, loss=0.006928201764822006\np_weight=tensor([[ 1.0085],\n        [-0.1180]], requires_grad=True)\np_bias=tensor([0.0699], requires_grad=True)\nepoch=40, loss=0.004014096688479185\np_weight=tensor([[ 1.0698],\n        [-0.0553]], requires_grad=True)\np_bias=tensor([0.0995], requires_grad=True)\nepoch=60, loss=0.002837373409420252\np_weight=tensor([[ 1.1095],\n        [-0.0157]], requires_grad=True)\np_bias=tensor([0.1187], requires_grad=True)\nepoch=80, loss=0.0021992085967212915\np_weight=tensor([[1.1390],\n        [0.0131]], requires_grad=True)\np_bias=tensor([0.1330], requires_grad=True)\nepoch=100, loss=0.0017981810960918665\np_weight=tensor([[1.1625],\n        [0.0359]], requires_grad=True)\np_bias=tensor([0.1445], requires_grad=True)\nepoch=120, loss=0.0015225138049572706\np_weight=tensor([[1.1821],\n        [0.0546]], requires_grad=True)\np_bias=tensor([0.1541], requires_grad=True)\nepoch=140, loss=0.00132126419339329\np_weight=tensor([[1.1990],\n        [0.0704]], requires_grad=True)\np_bias=tensor([0.1623], requires_grad=True)\nepoch=160, loss=0.001167768961749971\np_weight=tensor([[1.2137],\n        [0.0842]], requires_grad=True)\np_bias=tensor([0.1696], requires_grad=True)\nepoch=180, loss=0.0010467563988640904\np_weight=tensor([[1.2269],\n        [0.0964]], requires_grad=True)\np_bias=tensor([0.1760], requires_grad=True)\nepoch=200, loss=0.0009488696232438087\np_weight=tensor([[1.2387],\n        [0.1073]], requires_grad=True)\np_bias=tensor([0.1818], requires_grad=True)\n"
    }
   ],
   "source": [
    "epoch_nums = 200\n",
    "batch_nums = 2\n",
    "step_ratio = 0.03\n",
    "sample_nums = 10\n",
    "p_weight, p_bias = params_init()\n",
    "print(p_weight)\n",
    "print(p_bias)\n",
    "\n",
    "for epoch in range(0, epoch_nums):\n",
    "    batchs = sample_batchs(X, true_Y, sample_nums, batch_nums)\n",
    "    for batch in batchs:\n",
    "        b_X, b_true_Y = batch\n",
    "        batch_size = b_true_Y.shape[0]\n",
    "        b_hat_Y = logic_reg_model(p_weight, p_bias, b_X, batch_size)\n",
    "        loss = loss_func(b_true_Y, b_hat_Y).sum()\n",
    "\n",
    "        loss.backward()\n",
    "        p_weight.data -= step_ratio*p_weight.grad/batch_size\n",
    "        p_bias.data -= step_ratio*p_bias.grad/batch_size\n",
    "\n",
    "        p_weight.grad.data.zero_()\n",
    "        p_bias.grad.data.zero_()\n",
    "    if (epoch+1) % 20 == 0:\n",
    "        with torch.no_grad():\n",
    "            hat_Y = logic_reg_model(p_weight, p_bias, X, sample_nums)\n",
    "            loss = loss_func(true_Y, hat_Y)\n",
    "            print(\"epoch={}, loss={}\".format(epoch+1, loss.mean()))\n",
    "            print(\"p_weight={}\".format(p_weight))\n",
    "            print(\"p_bias={}\".format(p_bias))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 逻辑回归高级版本实现\n",
    "通过 PyTorch Lightning 来实现， https://github.com/PyTorchLightning/pytorch-lightning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import nessasary lib\n",
    "import os\n",
    "import torch\n",
    "from torch.nn import functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torch.utils.data import TensorDataset\n",
    "import pytorch_lightning as pl\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogicRegModel(pl.LightningModule):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(LogicRegModel, self).__init__()\n",
    "        # 定义模型结构\n",
    "        self.l1 = torch.nn.Linear(2, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # 必须：定义模型\n",
    "        return torch.sigmoid(self.l1(x))\n",
    "\n",
    "    def training_step(self, batch, batch_nb):\n",
    "        # 必须提供：定于训练过程\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        loss = F.binary_cross_entropy(y_hat, y)\n",
    "        tensorboard_logs = {'train_loss': loss}\n",
    "        return {'loss': loss, 'log': tensorboard_logs}\n",
    "\n",
    "    def validation_step(self, batch, batch_nb):\n",
    "        # 可选提供：定义验证过程\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        \n",
    "        return {'val_loss': F.binary_cross_entropy(y_hat, y)}\n",
    "\n",
    "    def validation_epoch_end(self, outputs):\n",
    "        # 可选提供：定义验证过程\n",
    "        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()\n",
    "        tensorboard_logs = {'val_loss': avg_loss}\n",
    "        return {'val_loss': avg_loss, 'log': tensorboard_logs}\n",
    "\n",
    "    def test_step(self, batch, batch_nb):\n",
    "        # 可选提供：定义测试过程\n",
    "        x, y = batch\n",
    "        y_hat = self(x)\n",
    "        return {'test_loss': F.binary_cross_entropy(y_hat, y)}\n",
    "\n",
    "    def test_epoch_end(self, outputs):\n",
    "        # 可选提供：定义测试过程\n",
    "        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()\n",
    "        logs = {'test_loss': avg_loss}\n",
    "        return {'test_loss': avg_loss, 'log': logs, 'progress_bar': logs}\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        # 必须提供：定义优化器\n",
    "        # can return multiple optimizers and learning_rate schedulers\n",
    "        # (LBFGS it is automatically supported, no need for closure function)\n",
    "        return torch.optim.SGD(self.parameters(), lr=0.1)\n",
    "\n",
    "    def gen_data_loader(self, shuffle, batch_size):\n",
    "        X = torch.tensor([[2., 1.],\n",
    "            [2., 2.],\n",
    "            [5., 4.],\n",
    "            [4., 5.],\n",
    "            [2., 3.],\n",
    "            [3., 2.],\n",
    "            [6., 5.],\n",
    "            [4., 1.],\n",
    "            [6., 3.],\n",
    "            [7, 4]], dtype=torch.float\n",
    "        )\n",
    "        Y = torch.tensor([0, 0, 1, 1, 0, 0, 1, 0, 1, 1], dtype=torch.float)\n",
    "        # 先转换成 torch 能识别的 Dataset\n",
    "        torch_dataset = TensorDataset(X, Y)\n",
    "\n",
    "        # 把 dataset 放入 DataLoader\n",
    "        loader = DataLoader(\n",
    "            dataset=torch_dataset,      # torch TensorDataset format\n",
    "            batch_size=batch_size,      # mini batch size\n",
    "            shuffle=shuffle,            # 要不要打乱数据 (打乱比较好)\n",
    "            num_workers=4,              # 多线程来读数据\n",
    "        )\n",
    "        return loader\n",
    "\n",
    "    def train_dataloader(self):\n",
    "        # 必须提供：提供训练数据集\n",
    "        return self.gen_data_loader(True, 2)\n",
    "\n",
    "    def val_dataloader(self):\n",
    "        # 可选提供：提供验证数据集\n",
    "        return self.gen_data_loader(False, 10)\n",
    "\n",
    "    def test_dataloader(self):\n",
    "        # 可选提供：提供测试数据集\n",
    "        return self.gen_data_loader(False, 10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stderr",
     "text": "GPU available: False, used: False\nTPU available: False, using: 0 TPU cores\n\n  | Name | Type   | Params\n--------------------------------\n0 | l1   | Linear | 3     \n"
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "8c4c36129f7643008d2d1cbe0a7d1981"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "1a331676e8504928a0c0bffd5422b8b5"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "9ca48c8e01da4aec83b63e931ddaeb68"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "7e2e60fc4fa64e46b8085ee7f3f5a81e"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "23d4c968eb6443578d61145ac163e74d"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "32608736afec43c0b268eb61a51afa35"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "c9e6d8f3ff74481c83d5acda6afff7c5"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "f694335d0fd144c7a8db78fad3bb6186"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "c8ad69e83a17429e902ed6928ff81d6f"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "a9f1571d86e64178a98f282fd8162d30"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "c873fe513d4d4848974988089651e475"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "099bb893db45491abf1df17da3297b9a"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "649c2a57480244fbb8472f472df14059"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "b39fee41a5de4c8b856c6fa44a587f00"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d5cb5e19ecba461cac413bea03c07dd0"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "de2cb7eca0144deca0a76c53e0a38f07"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d37e6ed369ce4badb451a1d26da975fb"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "71c980f2613147778f1b0da703f8f034"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "a4c915857b1c42e4a496bd7ec0a86fd6"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "43bc9249804947dc9f53aee8ea9af0b9"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "28aaf1a1f449412893b2f524e75f464f"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "454df1f9311747268f62e86954816480"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "ad70dfe5842d48cf96ec64a9211fcc98"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "e2c4ab77737e4f85bb85a738e7d6d628"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "1283a24ff8c54595a20181607b6e1f55"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "07cce7d861aa490dac51340d995df07f"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "1427f323880140a797e80e6e7bc37678"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "65110a3c4b5b4d778f72da56b8bf2d42"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "52d3adeba78446b783632215242b001f"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "1cd04be850f44389add1b51573776a43"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d3325f9ca79341e4b36190f7679e3662"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "4e2f4c630df54413b8fee220f8307116"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "0f007931c64c451b8293e29e84aeb66c"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "93e348fe5daf43789b38ba61a06af867"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "1dc5c9b4ed904ed7916c37d8fb023722"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "2519fd71a1af480ebb3081d6a05a5334"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d7b484ad06834b7b9a28eed24be9acde"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "6093f2b05ed243968749d150f469fe45"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "962e5aedcec640a7918f0be16cab354e"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "89a08ba6ba5542c48344ae897c3bad3a"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "5766cc9204e340dba07ae30508ee36da"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "fbe985d8fd354632b392af00137f23eb"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "d11d9732bff743cabf070075b224a62e"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "bc2deff66bbf46e18052db6059280461"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "6f663db045994b97930d88361b9dbe48"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "9fe3fda654db43049c0aa363b71cf6a4"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "5aaa4d9b7692472784221b4184fc4495"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "93f7c8f241b44d6aa757227d83a9c5eb"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "1"
     },
     "metadata": {},
     "execution_count": 34
    }
   ],
   "source": [
    "logic_model = LogicRegModel()\n",
    "\n",
    "# Basic trainer with default settings: train for at most 100 epochs and skip the sanity validation check\n",
    "trainer = pl.Trainer(max_epochs=100, num_sanity_val_steps=0)\n",
    "trainer.fit(logic_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "Parameter containing:\ntensor([[0.6615, 0.8123]], requires_grad=True)\nParameter containing:\ntensor([-4.4929], requires_grad=True)\n"
    }
   ],
   "source": [
    "# Print all learned parameters (weights and bias)\n",
    "for i in logic_model.parameters():\n",
    "    print(i)\n",
    "\n",
    "# Reference weight and bias values; they are defined here but not used in the remaining cells\n",
    "true_weight = torch.tensor([2, -3.4], dtype=torch.float32).view(2, 1)\n",
    "true_bias = torch.tensor(4.2, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "output_type": "display_data",
     "data": {
      "text/plain": "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…",
      "application/vnd.jupyter.widget-view+json": {
       "version_major": 2,
       "version_minor": 0,
       "model_id": "ec7cd3cae3bd47e9bcad417680bd9575"
      }
     },
     "metadata": {}
    },
    {
     "output_type": "stream",
     "name": "stdout",
     "text": "--------------------------------------------------------------------------------\nTEST RESULTS\n{'test_loss': tensor(0.1366)}\n--------------------------------------------------------------------------------\n\n"
    },
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": "{'test_loss': 0.1365588903427124}"
     },
     "metadata": {},
     "execution_count": 30
    }
   ],
   "source": [
    "trainer.test()"
   ]
  }
 ]
}